{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## New Classes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The **LAN extension (HDDM >= 0.9.0)** provides three new classes, which are LAN-enabled versions of the corresponding classes in base HDDM.\n",
    "These new classes are:\n",
    "\n",
    "- The `HDDMnn()` class\n",
    "- The `HDDMnnStimCoding()` class\n",
    "- The `HDDMnnRegressor()` class"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The usage mirrors what you are used to from the standard `HDDM` equivalents.\n",
    "\n",
    "What changes is that you now use the `model` argument to specify one of the models listed in the `hddm.model_config.model_config` dictionary (you can also provide a custom model; see the respective section of this documentation).\n",
    "\n",
    "Moreover, you have to be a little more careful when specifying the `include` argument, since new models come with new parameters. To help you get started, *every* model-specific sub-dictionary in `hddm.model_config.model_config` provides a `hddm_include` key. Passing this list as `include` lets you fit all parameters of the given model. To keep some parameters fixed, remove them from the list."
   ]
  },
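  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of that last step, trimming an include list before passing it as `include`. The helper `drop_fixed` and the example list are hypothetical and only illustrate the idea; in practice the list should be read from `hddm.model_config.model_config[model][\"hddm_include\"]`, and its contents depend on the chosen model.\n",
    "\n",
    "```python\n",
    "def drop_fixed(include, fixed):\n",
    "    \"\"\"Return a copy of `include` without the parameters named in `fixed`.\"\"\"\n",
    "    fixed = set(fixed)\n",
    "    return [param for param in include if param not in fixed]\n",
    "\n",
    "\n",
    "# Illustrative stand-in for a model's `hddm_include` list (assumed contents).\n",
    "full_include = [\"z\", \"theta\"]\n",
    "\n",
    "# Keep `z` fixed: fit everything in the list except `z`.\n",
    "include = drop_fixed(full_include, [\"z\"])\n",
    "print(include)\n",
    "```\n",
    "\n",
    "The resulting list can then be passed as `include=include` when constructing `hddm.HDDMnn`. Building a new list instead of calling `.remove()` on the original leaves the entry in `model_config` untouched for later use."
   ]
  },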
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Short example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import hddm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using default priors: Uninformative\n"
     ]
    }
   ],
   "source": [
    "model = \"angle\"\n",
    "cavanagh_data = hddm.load_csv(hddm.__path__[0] + \"/examples/cavanagh_theta_nn.csv\")\n",
    "model_ = hddm.HDDMnn(\n",
    "    cavanagh_data,\n",
    "    model=model,\n",
    "    include=hddm.model_config.model_config[model][\"hddm_include\"],\n",
    "    is_group_model=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " [-----------------100%-----------------] 1000 of 1000 complete in 260.3 sec"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<pymc.MCMC.MCMC at 0x13e449650>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_.sample(1000, burn=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>v</th>\n",
       "      <th>a</th>\n",
       "      <th>z_trans</th>\n",
       "      <th>t</th>\n",
       "      <th>theta</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.370402</td>\n",
       "      <td>1.325747</td>\n",
       "      <td>0.023242</td>\n",
       "      <td>0.284196</td>\n",
       "      <td>0.253870</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.338917</td>\n",
       "      <td>1.328545</td>\n",
       "      <td>0.062895</td>\n",
       "      <td>0.283047</td>\n",
       "      <td>0.248485</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.386179</td>\n",
       "      <td>1.321476</td>\n",
       "      <td>0.054727</td>\n",
       "      <td>0.285712</td>\n",
       "      <td>0.250671</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.387484</td>\n",
       "      <td>1.323711</td>\n",
       "      <td>-0.019109</td>\n",
       "      <td>0.274198</td>\n",
       "      <td>0.253445</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.370557</td>\n",
       "      <td>1.323342</td>\n",
       "      <td>0.015675</td>\n",
       "      <td>0.277691</td>\n",
       "      <td>0.255681</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>795</th>\n",
       "      <td>0.325748</td>\n",
       "      <td>1.331846</td>\n",
       "      <td>0.113685</td>\n",
       "      <td>0.270311</td>\n",
       "      <td>0.252461</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>796</th>\n",
       "      <td>0.337564</td>\n",
       "      <td>1.315446</td>\n",
       "      <td>0.111898</td>\n",
       "      <td>0.286141</td>\n",
       "      <td>0.252236</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>797</th>\n",
       "      <td>0.387142</td>\n",
       "      <td>1.309284</td>\n",
       "      <td>0.036839</td>\n",
       "      <td>0.286663</td>\n",
       "      <td>0.238878</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>798</th>\n",
       "      <td>0.388073</td>\n",
       "      <td>1.313791</td>\n",
       "      <td>-0.013604</td>\n",
       "      <td>0.271768</td>\n",
       "      <td>0.235831</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>799</th>\n",
       "      <td>0.397477</td>\n",
       "      <td>1.314008</td>\n",
       "      <td>-0.007186</td>\n",
       "      <td>0.276948</td>\n",
       "      <td>0.242729</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>800 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            v         a   z_trans         t     theta\n",
       "0    0.370402  1.325747  0.023242  0.284196  0.253870\n",
       "1    0.338917  1.328545  0.062895  0.283047  0.248485\n",
       "2    0.386179  1.321476  0.054727  0.285712  0.250671\n",
       "3    0.387484  1.323711 -0.019109  0.274198  0.253445\n",
       "4    0.370557  1.323342  0.015675  0.277691  0.255681\n",
       "..        ...       ...       ...       ...       ...\n",
       "795  0.325748  1.331846  0.113685  0.270311  0.252461\n",
       "796  0.337564  1.315446  0.111898  0.286141  0.252236\n",
       "797  0.387142  1.309284  0.036839  0.286663  0.238878\n",
       "798  0.388073  1.313791 -0.013604  0.271768  0.235831\n",
       "799  0.397477  1.314008 -0.007186  0.276948  0.242729\n",
       "\n",
       "[800 rows x 5 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_.get_traces()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>2.5q</th>\n",
       "      <th>25q</th>\n",
       "      <th>50q</th>\n",
       "      <th>75q</th>\n",
       "      <th>97.5q</th>\n",
       "      <th>mc err</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>v</th>\n",
       "      <td>0.369154</td>\n",
       "      <td>0.0207375</td>\n",
       "      <td>0.329893</td>\n",
       "      <td>0.355813</td>\n",
       "      <td>0.369495</td>\n",
       "      <td>0.382592</td>\n",
       "      <td>0.409568</td>\n",
       "      <td>0.00111918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>a</th>\n",
       "      <td>1.31224</td>\n",
       "      <td>0.0212032</td>\n",
       "      <td>1.26826</td>\n",
       "      <td>1.29879</td>\n",
       "      <td>1.31332</td>\n",
       "      <td>1.32755</td>\n",
       "      <td>1.3514</td>\n",
       "      <td>0.00180178</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>z</th>\n",
       "      <td>0.504951</td>\n",
       "      <td>0.00604908</td>\n",
       "      <td>0.493251</td>\n",
       "      <td>0.500775</td>\n",
       "      <td>0.504934</td>\n",
       "      <td>0.509041</td>\n",
       "      <td>0.517023</td>\n",
       "      <td>0.000311986</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>t</th>\n",
       "      <td>0.283719</td>\n",
       "      <td>0.00943542</td>\n",
       "      <td>0.265774</td>\n",
       "      <td>0.277639</td>\n",
       "      <td>0.283707</td>\n",
       "      <td>0.290191</td>\n",
       "      <td>0.302331</td>\n",
       "      <td>0.00070058</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>theta</th>\n",
       "      <td>0.242432</td>\n",
       "      <td>0.0127552</td>\n",
       "      <td>0.216824</td>\n",
       "      <td>0.234284</td>\n",
       "      <td>0.242875</td>\n",
       "      <td>0.251645</td>\n",
       "      <td>0.265587</td>\n",
       "      <td>0.00103379</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           mean         std      2.5q       25q       50q       75q     97.5q  \\\n",
       "v      0.369154   0.0207375  0.329893  0.355813  0.369495  0.382592  0.409568   \n",
       "a       1.31224   0.0212032   1.26826   1.29879   1.31332   1.32755    1.3514   \n",
       "z      0.504951  0.00604908  0.493251  0.500775  0.504934  0.509041  0.517023   \n",
       "t      0.283719  0.00943542  0.265774  0.277639  0.283707  0.290191  0.302331   \n",
       "theta  0.242432   0.0127552  0.216824  0.234284  0.242875  0.251645  0.265587   \n",
       "\n",
       "            mc err  \n",
       "v       0.00111918  \n",
       "a       0.00180178  \n",
       "z      0.000311986  \n",
       "t       0.00070058  \n",
       "theta   0.00103379  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_.gen_stats()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hddmnn_tutorial",
   "language": "python",
   "name": "hddmnn_tutorial"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
